d434f6
@@ -42,6 +42,7 @@
 import org.apache.hadoop.hive.ql.exec.SelectOperator;
 import org.apache.hadoop.hive.ql.exec.TableScanOperator;
 import org.apache.hadoop.hive.ql.exec.UDF;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
 import org.apache.hadoop.hive.ql.exec.Utilities;
 import org.apache.hadoop.hive.ql.lib.Node;
 import org.apache.hadoop.hive.ql.lib.NodeProcessor;
@@ -65,11 +66,16 @@
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredJavaObject;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBaseCompare;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFBridge;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCoalesce;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPAnd;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNot;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotNull;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNull;
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPOr;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDFWhen;
 import org.apache.hadoop.hive.serde.serdeConstants;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
@@ -79,9 +84,11 @@
 import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector.PrimitiveCategory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableConstantBooleanObjectInspector;
 import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.io.BooleanWritable;
 
 import com.google.common.collect.ImmutableSet;
 
@@ -199,10 +206,11 @@
public static ExprNodeDesc foldExpr(ExprNodeGenericFuncDesc funcDesc) {
    * @param op processing operator
    * @param propagate if true, assignment expressions will be added to constants.
    * @return fold expression
+   * @throws UDFArgumentException if building a simplified function expression fails
    */
   private static ExprNodeDesc foldExpr(ExprNodeDesc desc, Map<ColumnInfo, ExprNodeDesc> constants,
       ConstantPropagateProcCtx cppCtx, Operator<? extends Serializable> op, int tag,
-      boolean propagate) {
+      boolean propagate) throws UDFArgumentException {
     if (desc instanceof ExprNodeGenericFuncDesc) {
       ExprNodeGenericFuncDesc funcDesc = (ExprNodeGenericFuncDesc) desc;
 
@@ -356,7 +364,16 @@
private static ExprNodeColumnDesc getColumnExpr(ExprNodeDesc expr) {
     return (expr instanceof ExprNodeColumnDesc) ? (ExprNodeColumnDesc)expr : null;
   }
 
-  private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc> newExprs) {
+  /**
+   * Tries to replace a call to {@code udf} with a simpler, equivalent expression.
+   *
+   * @param udf the function being called
+   * @param newExprs the argument expressions of the call
+   * @return the simplified expression, or {@code null} if no shortcut applies
+   * @throws UDFArgumentException if a replacement function expression cannot be built
+   */
+  private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc> newExprs)
+      throws UDFArgumentException {
     if (udf instanceof GenericUDFOPAnd) {
       for (int i = 0; i < 2; i++) {
         ExprNodeDesc childExpr = newExprs.get(i);
@@ -407,6 +415,111 @@
private static ExprNodeDesc shortcutFunction(GenericUDF udf, List<ExprNodeDesc>
       }
     }
 
+    if (udf instanceof GenericUDFWhen) {
+      // Only single-branch WHEN expressions are handled:
+      //   newExprs = [cond, then] or [cond, then, else].
+      // In general, WHEN can have an unlimited number of branches.
+      if (!(newExprs.size() == 2 || newExprs.size() == 3)) {
+        return null;
+      }
+      ExprNodeDesc whenExpr = newExprs.get(0);
+      ExprNodeDesc thenExpr = newExprs.get(1);
+      if (thenExpr instanceof ExprNodeNullDesc
+          && (newExprs.size() == 2 || newExprs.get(2) instanceof ExprNodeNullDesc)) {
+        return thenExpr;
+      }
+      // A missing ELSE yields NULL of the THEN branch's type. With two children
+      // there is no index 2, so the type must be taken from thenExpr.
+      ExprNodeDesc elseExpr = newExprs.size() == 3 ? newExprs.get(2) :
+        new ExprNodeConstantDesc(thenExpr.getTypeInfo(), null);
+
+      if (whenExpr instanceof ExprNodeConstantDesc) {
+        // A NULL or false condition selects the ELSE branch.
+        Boolean whenVal = (Boolean) ((ExprNodeConstantDesc) whenExpr).getValue();
+        return Boolean.TRUE.equals(whenVal) ? thenExpr : elseExpr;
+      }
+
+      if (thenExpr instanceof ExprNodeConstantDesc && elseExpr instanceof ExprNodeConstantDesc) {
+        Object thenVal = ((ExprNodeConstantDesc) thenExpr).getValue();
+        Object elseVal = ((ExprNodeConstantDesc) elseExpr).getValue();
+        if (thenVal == null) {
+          return elseVal == null ? thenExpr : null;
+        } else if (thenVal.equals(elseVal)) {
+          // Both branches yield the same value, whatever the condition.
+          return thenExpr;
+        } else if (thenVal instanceof Boolean && elseVal instanceof Boolean) {
+          return foldBooleanBranches(whenExpr, Boolean.TRUE.equals(thenVal));
+        } else {
+          return null;
+        }
+      }
+    }
+    if (udf instanceof GenericUDFCase) {
+      // HIVE-9644: fold single-branch simple CASE expressions with constant results,
+      // e.g. CASE key WHEN v THEN true ELSE false END -> COALESCE(key = v, false).
+      // newExprs = [key, v, then] or [key, v, then, else]. In general CASE can have
+      // an unlimited number of branches; only one is handled here.
+      if (!(newExprs.size() == 3 || newExprs.size() == 4)) {
+        return null;
+      }
+      ExprNodeDesc thenExpr = newExprs.get(2);
+      if (thenExpr instanceof ExprNodeNullDesc
+          && (newExprs.size() == 3 || newExprs.get(3) instanceof ExprNodeNullDesc)) {
+        return thenExpr;
+      }
+      // A missing ELSE yields NULL of the THEN branch's type.
+      ExprNodeDesc elseExpr = newExprs.size() == 4 ? newExprs.get(3) :
+        new ExprNodeConstantDesc(thenExpr.getTypeInfo(), null);
+
+      if (thenExpr instanceof ExprNodeConstantDesc && elseExpr instanceof ExprNodeConstantDesc) {
+        Object thenVal = ((ExprNodeConstantDesc) thenExpr).getValue();
+        Object elseVal = ((ExprNodeConstantDesc) elseExpr).getValue();
+        if (thenVal == null) {
+          return elseVal == null ? thenExpr : null;
+        } else if (thenVal.equals(elseVal)) {
+          return thenExpr;
+        } else if (thenVal instanceof Boolean && elseVal instanceof Boolean) {
+          // Copy the operands: subList() is only a view backed by newExprs.
+          List<ExprNodeDesc> keyAndValue =
+              new java.util.ArrayList<ExprNodeDesc>(newExprs.subList(0, 2));
+          ExprNodeDesc matches =
+              ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPEqual(), keyAndValue);
+          return foldBooleanBranches(matches, Boolean.TRUE.equals(thenVal));
+        } else {
+          return null;
+        }
+      }
+    }
+
     return null;
   }
+
+  /**
+   * Folds a single-branch conditional whose THEN/ELSE results are the constants
+   * {@code thenIsTrue} and {@code !thenIsTrue}.
+   * <p>
+   * Returning {@code cond} (or {@code NOT cond}) directly would be wrong when
+   * {@code cond} evaluates to NULL: the conditional then takes its ELSE branch,
+   * whereas the bare condition would yield NULL. Wrapping it as
+   * {@code COALESCE(cond, false)} maps NULL to false and keeps the results equal.
+   *
+   * @param cond boolean condition selecting the THEN branch
+   * @param thenIsTrue whether the THEN branch is the constant {@code true}
+   * @return {@code COALESCE(cond, false)} or {@code NOT COALESCE(cond, false)}
+   * @throws UDFArgumentException if a function expression cannot be constructed
+   */
+  private static ExprNodeDesc foldBooleanBranches(ExprNodeDesc cond, boolean thenIsTrue)
+      throws UDFArgumentException {
+    List<ExprNodeDesc> coalesceArgs = new java.util.ArrayList<ExprNodeDesc>(2);
+    coalesceArgs.add(cond);
+    coalesceArgs.add(new ExprNodeConstantDesc(Boolean.FALSE));
+    ExprNodeDesc nullSafeCond =
+        ExprNodeGenericFuncDesc.newInstance(new GenericUDFCoalesce(), coalesceArgs);
+    if (thenIsTrue) {
+      return nullSafeCond;
+    }
+    List<ExprNodeDesc> notArgs = new java.util.ArrayList<ExprNodeDesc>(1);
+    notArgs.add(nullSafeCond);
+    return ExprNodeGenericFuncDesc.newInstance(new GenericUDFOPNot(), notArgs);
+  }
 
